import gym
from gym import spaces
import highway_env
import numpy as np
import pytest
import cv2
from matplotlib import pyplot as plt
from rl_agents.agents.deep_q_network.pytorch import DQNAgent

# Video recording setup. The frame size is (width, height) and is the same
# 640x280 that the environment is configured to render at further down.
frameSize = (640, 280)
fourcc = cv2.VideoWriter_fourcc(*'DIVX')
# Positional arguments: output file, codec, frames per second, frame size.
out = cv2.VideoWriter('video-test-1.avi', fourcc, 4, frameSize)


# Under pytest, skip this module when PyTorch cannot be imported.
torch = pytest.importorskip("torch")

env = gym.make("highway-fast-v0")
# At this point action_space and observation_space have been built by the
# constructor from the default settings.

# Settings to override on top of the defaults.
steering_limit = np.pi / 3
action_overrides = {
    "type": "DiscreteAction",
    "steering_range": [-steering_limit, steering_limit],
    "longitudinal": True,
    "lateral": True,
    "dynamical": False,
}
config_overrides = {
    "lanes_count": 10,
    "screen_width": 640,
    "screen_height": 280,
    "action": action_overrides,
    "simulation_frequency": 8,
}
print(env.configure(config_overrides))

# The spaces were built before the overrides were applied, so rebuild them.
env.define_spaces()
print(env.action_space, env.config["lanes_count"])

# Build a DQN agent with its default configuration (no config passed) and
# start the first episode.
agent = DQNAgent(env, None)
state, info = env.reset()
print(state)

# Collect two batches' worth of transitions, but never more than 102 steps
# (the original loop broke out once the step index exceeded 100).
n = 2 * agent.config['batch_size']
print(n)
max_steps = min(n, 102)

# try/finally guarantees the video file is finalised even if a step raises.
try:
    for step in range(max_steps):
        # NOTE(review): the action config selects "DiscreteAction", so this
        # is presumably a discrete action index rather than an
        # [acceleration, steering] array -- confirm against highway-env.
        action = agent.act(state)
        print(action)
        obs, reward, done, truncated, info = env.step(action)
        # Only `done` is stored as the terminal flag, as in the original.
        agent.record(state, action, reward, obs, done, info)
        env.render()
        cur_frame = env.render(mode="rgb_array")
        # The frame is RGB, while OpenCV's VideoWriter expects BGR channel
        # order; convert so the recorded colours are not swapped.
        out.write(cv2.cvtColor(cur_frame, cv2.COLOR_RGB2BGR))
        # An episode is over when it terminates OR is truncated (e.g. by a
        # time limit); either way a fresh episode must be started.
        if done or truncated:
            state, info = env.reset()
        else:
            state = obs
finally:
    out.release()

print('done')
# Disabled check kept for reference:
# assert (len(agent.memory) == n or
#         len(agent.memory) == agent.config['memory_capacity'])